package com.yanmushi.aws.sqs;

import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import com.amazonaws.util.Base32;
import com.amazonaws.util.Md5Utils;
import com.yanmushi.aws.sqs.impl.DropIllegalMessageHandlerImpl;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.CollectionUtils;

import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Single-queue facade over an {@link AmazonSQS} client.
 *
 * <p>{@link #receive()} fetches at most one message and starts tracking it, and
 * {@link #ack(SqsObject)} deletes the message and stops tracking it. A background listener
 * started from {@link #afterPropertiesSet()} periodically extends the visibility timeout of
 * every tracked message by {@code timeout} seconds. This keeps a message hidden from other
 * consumers while it is still being processed.
 *
 * @author yinlei
 * @since 2021/1/26
 */
public class AwsSqsService implements InitializingBean {

    private static final Logger log = LoggerFactory.getLogger(AwsSqsService.class);

    /** Pause between two scans of the tracked messages, in milliseconds. */
    private static final long LISTENER_SCAN_INTERVAL_MILLIS = 1000L;

    private AmazonSQS amazonSQS;
    private MessageObjectConverter messageObjectConverter;
    private String queueName;
    /** Visibility timeout in seconds, used on receive and on every extension. */
    private int timeout = 60;
    /** When true, a failure to resolve the queue URL at startup is logged instead of thrown. */
    private boolean errorContinue = false;
    /** Visibility is extended once fewer than this many seconds of {@code timeout} remain. */
    private int updateVisibilityInterval = 5;
    /** Receives messages the converter rejects; defaults to {@code DropIllegalMessageHandlerImpl}. */
    private IllegalMessageHandler illegalMessageHandler;

    /** Whether message attributes are sent and requested; copied from the converter on startup. */
    private boolean useAttr = false;
    private String queueUrl;
    /**
     * In-flight messages, keyed by a hash of their receipt handle. The map is created eagerly
     * so that the service is usable even without Spring initialization.
     */
    private final ConcurrentHashMap<String, ObjTime> objTimeMap = new ConcurrentHashMap<>();

    /**
     * Sends an entity to the queue.
     *
     * <p>The converter produces the message body. The converter's message attributes are
     * forwarded only when it reports {@code useAttr()}.
     *
     * @param obj the entity to send
     */
    public void send(SqsObject obj) {
        Message message = messageObjectConverter.obj2Message(obj);

        SendMessageRequest req = new SendMessageRequest();
        req.withQueueUrl(queueUrl).withMessageBody(message.getBody());

        if (useAttr) {
            req.withMessageAttributes(message.getMessageAttributes());
        }

        amazonSQS.sendMessage(req);

        messageObjectConverter.log("send", obj);
    }

    /**
     * Receives at most one message and keeps extending its visibility until it is acked.
     *
     * @return the converted entity, or {@code null} when no message was available or the
     *         message could not be converted; in the latter case it is handed to the
     *         configured {@link IllegalMessageHandler}
     */
    public SqsObject receive() {
        ReceiveMessageRequest request = new ReceiveMessageRequest();
        request.withMaxNumberOfMessages(1)
                .withVisibilityTimeout(timeout)
                .withQueueUrl(queueUrl);
        if (useAttr) {
            request.withMessageAttributeNames("All");
        }
        ReceiveMessageResult resp = amazonSQS.receiveMessage(request);
        if (CollectionUtils.isEmpty(resp.getMessages())) {
            return null;
        }
        Message msg = resp.getMessages().get(0);

        SqsObject obj;
        try {
            obj = messageObjectConverter.message2Obj(msg);
        } catch (IllegalMessageException e) {
            // Keep the cause in the log so the conversion failure can be diagnosed.
            log.warn("illegal message, {}, {}", msg.getBody(), msg.getReceiptHandle(), e);
            illegalMessageHandler.handle(amazonSQS, queueUrl, msg);
            return null;
        }
        obj.setSqsMessageId(msg.getMessageId());
        obj.setSqsReceiptHandle(msg.getReceiptHandle());

        putObj(obj);

        messageObjectConverter.log("receive", obj);

        return obj;
    }

    /**
     * Confirms that a message was processed successfully. The message is deleted from the
     * queue and its visibility is no longer extended.
     *
     * @param obj an entity previously returned by {@link #receive()}
     */
    public void ack(SqsObject obj) {
        amazonSQS.deleteMessage(queueUrl, obj.getSqsReceiptHandle());
        cleanObj(obj);
        messageObjectConverter.log("ack", obj);
    }

    /**
     * Starts the background listener. Roughly once per second it extends the visibility of
     * tracked messages by one {@code timeout} when they are close to expiring.
     *
     * <p>The thread is a daemon so it never blocks JVM shutdown, and it exits when interrupted.
     */
    protected void listen() {
        Thread listener = new Thread(this::runVisibilityLoop, "aws-sqs-visibility-" + queueName);
        listener.setDaemon(true);
        listener.start();
    }

    /** Scans all tracked messages until the current thread is interrupted. */
    private void runVisibilityLoop() {
        log.info("start message visibility listener");
        while (!Thread.currentThread().isInterrupted()) {
            for (ObjTime time : objTimeMap.values()) {
                extendIfDue(time);
            }
            try {
                Thread.sleep(LISTENER_SCAN_INTERVAL_MILLIS);
            } catch (InterruptedException e) {
                // Restore the flag so the loop condition ends the thread. Continuing to loop
                // would make every following sleep throw immediately, spinning the CPU.
                Thread.currentThread().interrupt();
            }
        }
        log.info("stop message visibility listener");
    }

    /**
     * Extends one message's visibility if fewer than {@code updateVisibilityInterval}
     * seconds of its current timeout remain. Failures are logged rather than propagated, so a
     * single bad receipt handle cannot kill the listener thread. This covers the race with
     * {@link #ack(SqsObject)}, which deletes the message before untracking it.
     */
    private void extendIfDue(ObjTime time) {
        long elapsedSeconds = (System.currentTimeMillis() - time.getTime()) / 1000;
        if (elapsedSeconds + updateVisibilityInterval <= timeout) {
            return;
        }
        // Reset the timestamp first: this starts a new timeout window for the message.
        time.setTime(System.currentTimeMillis());
        try {
            amazonSQS.changeMessageVisibility(queueUrl,
                    time.getObj().getSqsReceiptHandle(),
                    timeout);
            messageObjectConverter.log("visibility", time.getObj());
        } catch (RuntimeException e) {
            log.warn("failed to extend message visibility, {}",
                    time.getObj().getSqsReceiptHandle(), e);
        }
    }

    /** Builds a compact, stable map key from a receipt handle (Base32 of its MD5 digest). */
    private String md5ReceiptHandle(String receiptHandle) {
        return Base32.encodeAsString(
                Md5Utils.computeMD5Hash(receiptHandle.getBytes(StandardCharsets.UTF_8)));
    }

    /** Starts tracking a received entity for visibility extension. */
    private void putObj(SqsObject obj) {
        String key = md5ReceiptHandle(obj.getSqsReceiptHandle());
        objTimeMap.put(key, new ObjTime(obj));
    }

    /** Stops tracking an entity. */
    private void cleanObj(SqsObject obj) {
        String key = md5ReceiptHandle(obj.getSqsReceiptHandle());
        objTimeMap.remove(key);
    }

    /**
     * Resolves the queue URL, reads configuration from the converter, installs the default
     * illegal-message handler if none was set, and starts the visibility listener.
     *
     * @throws Exception if the queue URL cannot be resolved and {@code errorContinue} is false
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        try {
            queueUrl = amazonSQS.getQueueUrl(queueName).getQueueUrl();
        } catch (Exception e) {
            if (errorContinue) {
                log.warn("skip queue url error", e);
            } else {
                throw e;
            }
        }
        useAttr = messageObjectConverter.useAttr();
        if (illegalMessageHandler == null) {
            illegalMessageHandler = new DropIllegalMessageHandlerImpl();
        }
        listen();
    }

    public AmazonSQS getAmazonSQS() {
        return amazonSQS;
    }

    public void setAmazonSQS(AmazonSQS amazonSQS) {
        this.amazonSQS = amazonSQS;
    }

    public MessageObjectConverter getMessageObjectConverter() {
        return messageObjectConverter;
    }

    public void setMessageObjectConverter(MessageObjectConverter messageObjectConverter) {
        this.messageObjectConverter = messageObjectConverter;
    }

    public int getTimeout() {
        return timeout;
    }

    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    public boolean getErrorContinue() {
        return errorContinue;
    }

    public void setErrorContinue(boolean errorContinue) {
        this.errorContinue = errorContinue;
    }

    public void setQueueName(String queueName) {
        this.queueName = queueName;
    }

    public String getQueueName() {
        return queueName;
    }

    public int getUpdateVisibilityInterval() {
        return updateVisibilityInterval;
    }

    public void setUpdateVisibilityInterval(int updateVisibilityInterval) {
        this.updateVisibilityInterval = updateVisibilityInterval;
    }

    public IllegalMessageHandler getIllegalMessageHandler() {
        return illegalMessageHandler;
    }

    /**
     * Sets the handler for messages the converter rejects. When left {@code null},
     * {@code DropIllegalMessageHandlerImpl} is installed by {@link #afterPropertiesSet()}.
     *
     * @param illegalMessageHandler the handler to use
     */
    public void setIllegalMessageHandler(IllegalMessageHandler illegalMessageHandler) {
        this.illegalMessageHandler = illegalMessageHandler;
    }

    /** A tracked in-flight entity plus the time (epoch millis) its current timeout started. */
    @Getter
    @Setter
    private static class ObjTime {
        private SqsObject obj;

        private long time;

        public ObjTime(SqsObject obj) {
            this.obj = obj;
            this.time = System.currentTimeMillis();
        }

    }
}
